#!/usr/bin/env python3
# =============================================================================
# @file    compare.py
# @brief   Compare splitter results to a given oracle
# @author  Michael Hucka <mhucka@caltech.edu>
# @license Please see the file named LICENSE in the project directory
# @website https://github.com/casics/spiral
# =============================================================================

import os
import plac
import sys
from   time import time

# Make the parent and grandparent directories of this script importable so
# that "spiral" can be found when running from a source checkout.
try:
    thisdir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # __file__ is not defined in some contexts (e.g., when the code is
    # pasted into an interactive session); fall back to paths relative to
    # the current working directory.
    thisdir = os.curdir
sys.path.append(os.path.join(thisdir, ".."))
sys.path.append(os.path.join(thisdir, "../.."))

# Oracle data read by main(): maps each identifier to its expected list of
# tokens.  It lives at module level, so entries accumulate if main() is
# called more than once in the same process.
cases = {}

@plac.annotations(
    lowercase = ('lowercase the results before comparison', 'flag',   'l'),
    summarize = ('summarize results only',                  'flag',   's'),
    splitter  = ('splitter to use: "ronin", "samurai"',     'option', 'u'),
)
def main(path, lowercase=False, summarize=False, splitter='ronin'):
    '''Compare the results of Spiral splitters against an oracle.  This
function requires one argument, a path to a file containing the oracle data
in tab-separated (TSV) form with two columns:

   identifier	each,token,split,separated,by,commas

Blank lines in the file are ignored.

The flag -l makes the split results be put in lower case before being compared
to the expected split (i.e., the second column in the TSV file).  Lower case
results are used in some oracles and not others.

The flag -s prints only the final summary line instead of also listing every
failed identifier.  The option -u selects the splitter ("ronin" or "samurai").

Example of use:

   ./compare.py ../data/ludiso.tsv
'''
    # The imports are deferred so that only the requested splitter is loaded.
    if splitter == 'ronin':
        from spiral import ronin
        splitter_split = ronin.split
    elif splitter == 'samurai':
        from spiral import samurai
        splitter_split = samurai.split
    else:
        # Without this branch an unknown name would only surface later as an
        # unbound-variable error when the splitter is first called.
        raise ValueError('unknown splitter "{}": expected "ronin" or "samurai"'
                         .format(splitter))

    # Read the oracle into the module-level "cases" dictionary.
    with open(path, 'r') as inputfile:
        for lineno, line in enumerate(inputfile, start=1):
            line = line.rstrip()
            if not line:
                continue
            columns = line.split('\t')
            if len(columns) != 2:
                raise ValueError('{}:{}: expected 2 tab-separated columns, '
                                 'found {}'.format(path, lineno, len(columns)))
            identifier, expected = columns
            cases[identifier] = expected.split(',')

    successes = 0
    failures = 0
    for identifier, expected in sorted(cases.items()):
        result = splitter_split(identifier)
        if lowercase and result:
            result = [x.lower() for x in result]
        if result == expected:
            successes += 1
        else:
            failures += 1
            if not summarize:
                print('failed {}: expected {} got {}'
                      .format(identifier, expected, result))

    # An empty oracle would otherwise cause a division by zero.
    total = successes + failures
    accuracy = 100*successes/total if total else 0.0
    print('\n{} successes, {} failures => {:.02f}% accuracy'
          .format(successes, failures, accuracy))


# Script entry point: plac builds the command-line parser from main's
# signature and annotations, then calls main with the parsed arguments.
if __name__ == '__main__':
    plac.call(main)
